// Hinge-loss forward pass, one element per thread.
//
// For each idx < num the thread computes y = pred[idx] * labl[idx] and a
// per-element loss of (1 - y) when y < 1, otherwise 0. The per-element losses
// of a block are then summed by thread 0 and written to loss[blockIdx.x].
//
// NOTE: loss is accumulated per-block. The caller is responsible for
// ensuring that the *loss array is big enough (one entry per block is
// written, indexed by blockIdx.x), and for summing it to obtain the final
// combined loss.
//
// Parameters:
//   pred - predictions, read at indices [0, num)
//   labl - labels, read at indices [0, num)
//   num  - number of elements in pred / labl
//   loss - output array; entry blockIdx.x receives this block's partial sum
//
// Launch requirements: only the .x components of threadIdx / blockIdx /
// blockDim are used, and blockDim.x must not exceed THREADS_PER_BLOCK_X
// because the shared scratch array has exactly that many slots.
template <typename T>
__device__ void hinge_loss_forward(const T * const __restrict__ pred,
                                   const T * const __restrict__ labl,
                                   int num,
                                         T * const __restrict__ loss) {
  __shared__ T local_loss[THREADS_PER_BLOCK_X];

  const T zero = static_cast<T>(0);
  const T one = static_cast<T>(1);

  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= num) {
    // out of bounds, set local loss to zero so that we can safely accumulate later
    local_loss[threadIdx.x] = zero;
  } else {
    const T the_pred = pred[idx];
    const T the_labl = labl[idx];
    const T y = the_pred * the_labl;
    local_loss[threadIdx.x] = y<one ? one-y : zero;
  }

  // Every thread's slot must be written before thread 0 reads them.
  __syncthreads();
  if (0 == threadIdx.x) {
    // Only slots [0, blockDim.x) were written above. Summing all
    // THREADS_PER_BLOCK_X slots would read uninitialized shared memory when
    // the kernel is launched with a smaller block, so bound the loop by the
    // number of slots that actually hold data.
    const int block_threads = static_cast<int>(blockDim.x);
    const int capacity = static_cast<int>(THREADS_PER_BLOCK_X);
    const int filled = block_threads < capacity ? block_threads : capacity;

    // thread 0 does in-block accumulation
    T total_local_loss = zero;
    for (int i = 0; i < filled; ++i)
      total_local_loss += local_loss[i];

    // One partial sum per block; the caller combines them.
    loss[blockIdx.x] = total_local_loss;
  }
}

// Hinge-loss backward pass, one element per thread.
//
// For each idx < num:
//   - if labl[idx] * pred[idx] < 1:
//       gradient1[idx] = neg_weight * labl[idx]
//       gradient2[idx] = neg_weight * pred[idx]
//   - otherwise both gradients are set to 0.
// Either gradient pointer may be null, in which case that output is skipped.
// Threads whose idx is >= num do nothing.
//
// Only the .x components of threadIdx / blockIdx / blockDim are used.
template <typename T>
__device__ void hinge_loss_backward(const T * const __restrict__ pred,
                                    const T * const __restrict__ labl,
                                          T * const __restrict__ gradient1,
                                          T * const __restrict__ gradient2,
                                          int num, T neg_weight)
{
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < num) {
    // Load each operand once; both the test and the outputs reuse them.
    const T label = labl[idx];
    const T score = pred[idx];

    const T none = static_cast<T>(0);
    const bool below_one = label * score < static_cast<T>(1);

    if (gradient1) {
      gradient1[idx] = below_one ? neg_weight * label : none;
    }
    if (gradient2) {
      gradient2[idx] = below_one ? neg_weight * score : none;
    }
  }
}

extern "C" {
  // Single-precision kernel entry point (unmangled name via extern "C").
  // Forwards all arguments unchanged to hinge_loss_forward<float>.
  __global__ void hinge_loss_forward_float(
      const float * const __restrict__ pred,
      const float * const __restrict__ labl,
      int num,
      float * const __restrict__ loss)
  {
    hinge_loss_forward<float>(pred, labl, num, loss);
  }
  // Double-precision kernel entry point (unmangled name via extern "C").
  // Forwards all arguments unchanged to hinge_loss_forward<double>.
  __global__ void hinge_loss_forward_double(
      const double * const __restrict__ pred,
      const double * const __restrict__ labl,
      int num,
      double * const __restrict__ loss)
  {
    hinge_loss_forward<double>(pred, labl, num, loss);
  }
  // Single-precision kernel entry point (unmangled name via extern "C").
  // Forwards all arguments unchanged to hinge_loss_backward<float>.
  __global__ void hinge_loss_backward_float(
      const float * const __restrict__ pred,
      const float * const __restrict__ labl,
      float * const __restrict__ gradient1,
      float * const __restrict__ gradient2,
      int num, float neg_weight)
  {
    hinge_loss_backward<float>(pred, labl, gradient1, gradient2, num, neg_weight);
  }
  // Double-precision kernel entry point (unmangled name via extern "C").
  // Forwards all arguments unchanged to hinge_loss_backward<double>.
  __global__ void hinge_loss_backward_double(
      const double * const __restrict__ pred,
      const double * const __restrict__ labl,
      double * const __restrict__ gradient1,
      double * const __restrict__ gradient2,
      int num, double neg_weight)
  {
    hinge_loss_backward<double>(pred, labl, gradient1, gradient2, num, neg_weight);
  }
}

// vim: ft=cuda
